feat: add Qwen3.5 35B A3B model support on TRN2 #48
Open
YantaoShen wants to merge 4 commits into aws-neuron:main from
Conversation
vgene requested changes (Apr 3, 2026)
    all_logits = local_logits

    # Argmax on CPU
    next_id = all_logits.argmax(dim=-1, keepdim=True).to(dtype=torch.int)  # (B, 1)
Contributor
Let's make sure this is properly fixed before merging
…-topk
Replaces the CPU argmax path in _sample_token with a single greedy_sampling
kernel that does RMSNorm + lm_head matmul + all_gather (full logits) + global
topk entirely on device. Per-step DtoH traffic drops from (B, vocab_per_device)
f32 (~248 KB per sequence) to (B,) uint32 (4 bytes per sequence), and the gloo
all_gather and torch argmax on CPU are both eliminated.
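The transfer sizes quoted above can be checked with a quick back-of-envelope calculation (vocab_per_device = 62080 at TP=4 is taken from the commit message; B = 1 is illustrative):

```python
# Per-step device-to-host bytes per sequence, before and after the change.
B = 1                       # batch size (illustrative)
vocab_per_device = 62080    # vocab shard size at TP=4, per the commit message

old_bytes = B * vocab_per_device * 4   # (B, vocab_per_device) f32 logits
new_bytes = B * 4                      # (B,) uint32 token IDs

print(old_bytes)   # 248320 -> ~248 KB
print(new_bytes)   # 4
```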
Kernel graph is gather-then-topk, not the usual topk-then-gather-then-index:
neuronx-cc 2.23 miscompiles topk when all_gather is downstream in the same
kernel for certain vocab_per_device sizes (including the 62080 used at TP=4),
producing wrong token IDs. Keeping topk strictly downstream of all_gather
sidesteps this, and taking topk over the gathered full-vocab logits directly
returns the global winner ID: no rank-offset arithmetic, no dynamic indexing.
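The no-rank-offset property can be illustrated with a toy host-side simulation (numpy stands in for device tensors; shard shapes are made up for illustration, not taken from the kernel):

```python
import numpy as np

# Simulate TP ranks each holding a contiguous shard of the vocab logits.
tp, vocab_per_device = 4, 8
rng = np.random.default_rng(0)
shards = [rng.standard_normal(vocab_per_device) for _ in range(tp)]

# gather-then-topk: all_gather the shards, then one argmax over the full
# vocab -- the resulting index is already the global token ID.
full_logits = np.concatenate(shards)
global_id = int(full_logits.argmax())

# topk-then-gather would instead take a per-shard winner, gather the
# (value, local_id) pairs, and add rank * vocab_per_device afterwards --
# the rank-offset arithmetic the gather-then-topk ordering avoids.
per_rank = [(float(s.max()), r * vocab_per_device + int(s.argmax()))
            for r, s in enumerate(shards)]
offset_id = max(per_rank)[1]

assert offset_id == global_id
```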
- kernels/sampling.py: replace compute_logits with greedy_sampling
- qwen3_5.py: compile greedy_sampling, swap (B, vpd) f32 logits buffers
for (B,) uint32 next_id buffers, simplify _sample_token to a single
DtoH + reshape
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
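A minimal sketch of what the simplified host path might look like, assuming the buffer names from the diff (numpy stands in for the device tensor; the function and parameter names here are hypothetical, not the PR's actual code):

```python
import numpy as np

def sample_token(next_id_buf):
    """Hypothetical sketch of the simplified _sample_token host side:
    greedy_sampling has already written the winning token IDs into the
    (B,) uint32 buffer on device, so the host does one small copy plus
    a reshape back to the (B, 1) shape the old argmax path returned."""
    next_id = np.asarray(next_id_buf, dtype=np.uint32)  # the single (B,) DtoH copy
    return next_id.reshape(-1, 1).astype(np.int32)      # (B, 1) token IDs

ids = sample_token([17, 3])
print(ids.shape)   # (2, 1)
```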
Issue #, if available:
Description of changes:
feat: add Qwen3.5 35B A3B model support on TRN2
Add inference support for Qwen3.5-35B-A3B (MoE) on AWS Trainium2.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.